# 封装成函数，而不是notebook的形式
import numpy as np
from keras.applications import vgg16
from keras.preprocessing.image import load_img, img_to_array
from keras import backend as K
import tensorflow as tf
from scipy.optimize import fmin_l_bfgs_b
import time

from imageio import imwrite,imsave
# TensorFlow session configuration: allow ops to be placed on another device
# when the requested one has no kernel for them, and grow GPU memory on demand
# instead of reserving it all up front.
config = tf.compat.v1.ConfigProto(allow_soft_placement = True)
config.gpu_options.allow_growth = True
# NOTE(review): `sess` is not referenced anywhere else in this file and is never
# handed to Keras (no set_session call is visible), so these options may not
# apply to the graph Keras builds below — confirm.
sess = tf.compat.v1.Session(config = config)
# Switch to TF1-style graph mode; the code below builds a static graph with
# K.placeholder / K.gradients / K.function.
tf.compat.v1.disable_eager_execution()
def preprocess_image(image_path, height=None, width=None):
    """Load an image from disk and turn it into a one-image VGG16 input batch.

    Args:
        image_path: path of the image file to load.
        height: target height in pixels; 400 is used when falsy.
        width: target width in pixels. When falsy, it is derived from the
            source image's aspect ratio so the resized image is not distorted.

    Returns:
        A numpy array of shape (1, height, width, 3) after
        ``vgg16.preprocess_input`` has been applied (undone by
        ``deprocess_image``).
    """
    height = height if height else 400
    if not width:
        # Keep the source aspect ratio; PIL's ``size`` is (width, height).
        orig_width, orig_height = load_img(image_path).size
        width = int(orig_width * height / orig_height)
    img = load_img(image_path, target_size=(height, width))  # resize on load
    img = img_to_array(img)                # (height, width, 3)
    img = np.expand_dims(img, axis=0)      # -> (1, height, width, 3)
    img = vgg16.preprocess_input(img)      # VGG16-specific preprocessing
    return img
def deprocess_image(x):
    """Convert a VGG16-preprocessed image back into a displayable RGB array.

    Adds back the per-channel BGR mean pixel, flips the channel order from
    BGR to RGB, and clips to [0, 255].

    Args:
        x: array whose last axis holds the 3 BGR channels, e.g. (H, W, 3) or
            (1, H, W, 3). It is not modified.

    Returns:
        A new ``uint8`` array of the same shape with channels in RGB order.
    """
    x = np.asarray(x)
    # Work on a copy so the caller's array is left untouched. Integer input
    # is promoted to float32 because adding the fractional means in place
    # would otherwise fail.
    if np.issubdtype(x.dtype, np.floating):
        x = x.copy()
    else:
        x = x.astype('float32')
    # Remove zero-center by mean pixel (BGR order). `...` selects the last
    # (channel) axis regardless of the input's rank.
    x[..., 0] += 103.939
    x[..., 1] += 116.779
    x[..., 2] += 123.68
    # 'BGR'->'RGB'
    x = x[..., ::-1]
    return np.clip(x, 0, 255).astype('uint8')

def GetModel(TARGET_IMG, REFERENCE_STYLE_IMG, img_height=480):
    """Build a VGG16 whose input batch is [content, style, generated].

    Args:
        TARGET_IMG: path of the content image.
        REFERENCE_STYLE_IMG: path of the style image; it is resized to the
            same size as the content image.
        img_height: working height in pixels (default 480). The width is
            derived from the content image's aspect ratio.

    Returns:
        ``(model, img_height, img_width, generated_image)``, where
        ``generated_image`` is the (1, img_height, img_width, 3) placeholder
        the optimizer feeds.
    """
    # Load the content image once just to read its size; PIL's size is (w, h).
    width, height = load_img(TARGET_IMG).size
    img_width = int(width * img_height / height)
    print("height,width:", img_height, img_width, (width, height))
    target_image = K.constant(preprocess_image(TARGET_IMG, height=img_height, width=img_width))
    style_image = K.constant(preprocess_image(REFERENCE_STYLE_IMG, height=img_height, width=img_width))
    # Placeholder for the image being generated.
    generated_image = K.placeholder((1, img_height, img_width, 3))
    # Stack the three images into one batch of shape (3, img_height, img_width, 3)
    # so one forward pass yields features for all of them:
    # index 0 = content, 1 = style, 2 = generated.
    input_tensor = K.concatenate([target_image,
                                  style_image,
                                  generated_image], axis=0)
    print("input_tensor", input_tensor.shape)
    model = vgg16.VGG16(input_tensor=input_tensor,
                        weights='imagenet',
                        include_top=False)
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    return model, img_height, img_width, generated_image
def content_loss(base, combination):
    """Content loss: the sum (not the mean) of squared element-wise differences.

    Args:
        base: features of the content image.
        combination: features of the generated image, same shape as ``base``.

    Returns:
        Scalar tensor ``sum((combination - base) ** 2)``.
    """
    residual = combination - base
    squared = K.square(residual)
    return K.sum(squared)

def style_loss(style, combination, height, width, channels=3):
    """Style loss between two feature maps, compared through their Gram matrices.

    Args:
        style: feature map of the style image, channels last (rows, cols, C).
        combination: feature map of the generated image, same shape.
        height, width: image size used in the normalisation term.
        channels: channel count used in the normalisation term. The default
            of 3, together with ``height * width``, reproduces the scaling the
            weights in this script were tuned for. Gatys et al. normalise by
            the layer's filter count and feature-map size instead; pass those
            values to get that scaling.

    Returns:
        ``sum((G_style - G_comb) ** 2) / (4 * channels**2 * (height*width)**2)``.
    """
    def build_gram_matrix(x):
        # Put channels first, then flatten the spatial dims:
        # (rows, cols, C) -> (C, rows*cols).
        features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
        # (C, C) matrix of channel-wise inner products.
        return K.dot(features, K.transpose(features))

    S = build_gram_matrix(style)
    C = build_gram_matrix(combination)
    size = height * width
    return K.sum(K.square(S - C)) / (4. * (channels ** 2) * (size ** 2))
def total_variation_loss(x, img_height, img_width):
    """Total-variation regulariser for the generated image.

    Penalises the difference between each pixel and its lower and right
    neighbours, which encourages spatial smoothness.

    Args:
        x: image tensor indexed as (batch, rows, cols, channels).
        img_height, img_width: image size; the last row/column has no
            neighbour in one direction and is skipped.

    Returns:
        Scalar tensor ``sum((dy**2 + dx**2) ** 1.25)``.
    """
    last_row, last_col = img_height - 1, img_width - 1
    anchor = x[:, :last_row, :last_col, :]
    below = x[:, 1:, :last_col, :]
    right = x[:, :last_row, 1:, :]
    vertical = K.square(anchor - below)
    horizontal = K.square(anchor - right)
    return K.sum(K.pow(vertical + horizontal, 1.25))

def GetLoss(model, img_height, img_width, generated_image,
            content_weight=0.05, total_variation_weight=1e-4,
            style_weights=None):
    """Build the symbolic total loss for neural style transfer.

    The model's input batch must be [content, style, generated], so batch
    indices 0, 1 and 2 select those images' features.

        total = content_weight * content_loss(block4_conv2)
              + sum_i style_weights[i] * style_loss(block{i}_conv2), i = 1..5
              + total_variation_weight * total_variation_loss(generated_image)

    Args:
        model: Keras model with layers named block4_conv2 and block1..5_conv2.
        img_height, img_width: size of the generated image.
        generated_image: placeholder holding the generated image.
        content_weight: weight of the content term.
        total_variation_weight: weight of the total-variation term.
        style_weights: sequence with one weight per style layer (5 entries);
            defaults to [0.1, 0.15, 0.2, 0.25, 0.3].

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``style_weights`` does not have one entry per style layer.
    """
    layers = {l.name: l.output for l in model.layers}
    content_layer = 'block4_conv2'
    # The second conv layer of each of blocks 1-5.
    style_layers = ['block{}_conv2'.format(o) for o in range(1, 6)]
    if style_weights is None:
        style_weights = [0.1, 0.15, 0.2, 0.25, 0.3]
    if len(style_weights) != len(style_layers):
        # zip() would otherwise silently drop layers or weights.
        raise ValueError('expected %d style weights, got %d'
                         % (len(style_layers), len(style_weights)))
    loss = K.variable(0.)
    layer_features = layers[content_layer]
    target_image_features = layer_features[0, :, :, :]   # content image
    print("target_image_features", target_image_features.shape)
    combination_features = layer_features[2, :, :, :]    # generated image
    # Build the content term once and reuse it for logging, rather than
    # creating a second, unused subgraph just to print it.
    c_loss = content_loss(target_image_features, combination_features)
    print("loss", c_loss)
    loss = loss + content_weight * c_loss
    for layer_name, sw in zip(style_layers, style_weights):
        layer_features = layers[layer_name]
        style_reference_features = layer_features[1, :, :, :]  # style image
        combination_features = layer_features[2, :, :, :]      # generated image
        sl = style_loss(style_reference_features, combination_features,
                        height=img_height, width=img_width)
        loss += (sl * sw)
    # Smoothness regulariser on the generated image.
    loss += total_variation_weight * total_variation_loss(generated_image, img_height, img_width)
    return loss

class Evaluator:
    """Split one combined loss+gradient computation into two callbacks.

    ``scipy.optimize.fmin_l_bfgs_b`` takes the loss and its gradient as
    separate callables, but the fetch function computes both at once.
    ``loss`` runs the computation and caches the gradient. The following
    ``grads`` call returns that cached gradient and clears the cache.
    """

    def __init__(self, height=None, width=None, *, fetch_fn=None):
        """
        Args:
            height, width: size of the generated image; used to reshape the
                optimizer's flat vector into a (1, height, width, 3) batch.
            fetch_fn: callable taking ``[x]`` and returning ``[loss, grads]``.
                When None, the module-level ``fetch_loss_and_grads`` (created
                in the ``__main__`` section) is looked up at call time.
        """
        self.loss_value = None
        self.grad_values = None
        self.height = height
        self.width = width
        self.fetch_fn = fetch_fn

    def loss(self, x):
        """Return the loss at flat vector ``x`` and cache its gradient."""
        assert self.loss_value is None
        fetch = self.fetch_fn if self.fetch_fn is not None else fetch_loss_and_grads
        x = x.reshape((1, self.height, self.width, 3))
        outs = fetch([x])  # [loss, gradient]
        self.loss_value = outs[0]
        # L-BFGS-B works on flat float64 vectors.
        self.grad_values = outs[1].flatten().astype('float64')
        return self.loss_value

    def grads(self, x):
        """Return the gradient cached by the preceding ``loss`` call.

        ``x`` is ignored; it must be the same point just passed to ``loss``.
        """
        assert self.loss_value is not None
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values

if __name__ == '__main__':
    import os

    # Content (target) image
    TARGET_IMG = 'content.png'
    # Style reference image
    REFERENCE_STYLE_IMG = 'style.png'
    model, img_height, img_width, generated_image = GetModel(TARGET_IMG, REFERENCE_STYLE_IMG)
    loss = GetLoss(model, img_height, img_width, generated_image)
    print("generated_image", generated_image.shape)
    evaluator = Evaluator(height=img_height, width=img_width)
    # K.gradients returns a list with one gradient per variable. [0] is the
    # gradient w.r.t. generated_image and keeps its (1, height, width, 3) shape.
    grads = K.gradients(loss, generated_image)[0]
    print("grads", grads.shape)
    # Graph function mapping the generated image to [loss, gradient]. It must
    # stay a module-level name because Evaluator.loss looks it up as a global.
    fetch_loss_and_grads = K.function([generated_image], [loss, grads])
    # Start the optimisation from the content image. It is flattened because
    # fmin_l_bfgs_b optimises over a 1-D vector.
    x = preprocess_image(TARGET_IMG, height=img_height, width=img_width)
    x = x.flatten()
    # Outer optimisation rounds; each round allows up to maxfun=20 loss
    # evaluations. Increase this (e.g. to 20) for a more converged result.
    iterations = 1
    # Build the output prefix from the content image's base name, so a path
    # containing directories or extra dots still yields a file in the cwd.
    stem = os.path.splitext(os.path.basename(TARGET_IMG))[0]
    result_prefix = 'st_res_' + stem
    for i in range(iterations):
        print('Start of iteration', (i + 1))
        start_time = time.time()
        # L-BFGS-B optimiser: loss and gradient come from the Evaluator.
        x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x,
                                         fprime=evaluator.grads, maxfun=20)
        print('Current loss value:', min_val)
        print("info", info)
        # Save the first round, every 5th round, and always the final round.
        if (i + 1) % 5 == 0 or i == 0 or i == iterations - 1:
            img = x.copy().reshape((img_height, img_width, 3))
            img = deprocess_image(img)  # back to displayable uint8 RGB
            fname = result_prefix + '_iter%d.png' % (i + 1)
            imwrite(fname, img)
            print('Image saved as', fname)
        end_time = time.time()
        print('Iteration %d completed in %ds' % (i + 1, end_time - start_time))
